2025.5.26 ADによる固有値指定
無理矢理感が強いが
eigval = eig(X)
L = norm(eigval - target_eigval)
L.backeard()
によって求めたdL/dXを使ってXを更新することで、Xの固有値をtargetと同じ値に近づける。
結論としては可能であった。自動微分凄い。
code:real_eig.py
import torch as pt
import matplotlib.pyplot as plt
x = pt.rand(3,3, requires_grad=True)
target_eig = pt.tensor(1,2,3, dtype=pt.float) optimizer = pt.optim.Adam(x, lr=0.001) hist_x = []
hist_loss = []
hist_eig = []
for i in range(100000):
optimizer.zero_grad()
eig, d = pt.linalg.eig(x)
loss = pt.linalg.norm(eig - target_eig)
loss.backward()
hist_loss.append(loss.clone().detach())
hist_eig.append(eig.flatten())
hist_x.append(x.flatten())
optimizer.step()
hist_loss = pt.stack(hist_loss)
hist_eig = pt.stack(hist_eig)
plt.plot(hist_loss)
plt.show()
hist_eigR = pt.real(hist_eig).clone().detach()
hist_eigI = pt.imag(hist_eig).clone().detach()
for i in range(3):
plt.plot(hist_eigR:,i, hist_eigI:,i, '.') for i in range(3):
plt.plot(hist_eigR0,i, hist_eigI0,i, '+', c='blue') plt.plot(hist_eigR-1,i, hist_eigI-1,i, '*', c='red') plt.show()
print(eig)
軌道はランダムに与えた行列の初期値に依存する。
lossの推移、横軸は計算回数、単調増加ではないので刻み時間の選定を含めて感度の高い計算を行っていることが伺える。
解釈:勾配はそもそもlossを減少させる向きを意味しているので、その方向にXを変化させるとlossが増加することは無いはず、学習率が大きいため、極小値を行き過ぎた地点までXを更新しているのではないか。
https://scrapbox.io/files/6834146a66edfdf3225ce260.png
極の軌道をプロットした複素平面、ヨコが実軸タテが虚軸、ウネウネと動いているが確かに初期値からtargetで指定した極まで移動している。
https://scrapbox.io/files/68341471f2fcd7c11431c2c0.png
まっすぐtargetに向かうわけではないようだ。
xの値
| 1.2934, 0.1626, -0.4755 |
| 0.1753, 2.2490, 0.6616 |
| -0.4914, 0.5350, 2.4583 |
複素極はどうなる?
code:complex.py
import torch as pt
import matplotlib.pyplot as plt
x = pt.rand(3,3, requires_grad=True)
optimizer = pt.optim.Adam(x, lr=0.001) hist_x = []
hist_loss = []
hist_eig = []
for i in range(100000):
optimizer.zero_grad()
eig, d = pt.linalg.eig(x)
loss = pt.linalg.norm(eig - target_eig)
loss.backward()
hist_loss.append(loss.clone().detach())
hist_eig.append(eig.flatten())
hist_x.append(x.flatten())
optimizer.step()
hist_loss = pt.stack(hist_loss)
hist_eig = pt.stack(hist_eig)
plt.plot(hist_loss)
plt.show()
hist_eigR = pt.real(hist_eig).clone().detach()
hist_eigI = pt.imag(hist_eig).clone().detach()
for i in range(3):
plt.plot(hist_eigR:,i, hist_eigI:,i, '.') for i in range(3):
plt.plot(hist_eigR0,i, hist_eigI0,i, '+', c='blue') plt.plot(hist_eigR-1,i, hist_eigI-1,i, '*', c='red') plt.show()
print(eig)
loss
https://scrapbox.io/files/6834136031c17cab37e868bc.png
極
https://scrapbox.io/files/6834136a37eca93a22a6d6d0.png
指定に近い値に配置されているようだ。
指定:
|1 + 1j, 1 - 1j, 3 |
結果:
| 0.9992+0.9999j, 0.9992-0.9999j, 3.0010+0.0000j |
xの値
| 1.2508, -0.9212, 1.6104 |
| 0.9452, 1.1456, 0.7740 |
| 0.1038, 0.4831, 2.6040 |